import torch
from torch.optim import Adam as Optimizer
from UPMSPModel import *
from UPMSPEnv import * 

class UPMSP_Trainer:
    """Policy-gradient trainer that fits a ``UPMSP_Model`` on sampled environments.

    Every local iteration builds a fresh ``Parallel_machine_tardiness``
    environment, rolls the model out on it with a randomly sampled latent
    vector per rollout, and back-propagates a loss built from the
    log-probabilities of the highest-reward rollout of each batch instance.
    A checkpoint of the model weights is written after every epoch.

    Args:
        env_params: Mapping handed to ``Parallel_machine_tardiness``. This
            class reads ``'mode'``, ``'fine_tuning_path'`` (only when mode is
            ``"fine_tuning"``), ``'batch_size'`` and ``'pomo_size'`` from it.
        model_params: Mapping handed to ``UPMSP_Model``. This class reads
            ``'latent_cont_dim'`` and ``'latent_disc_dim'`` from it.
        optimizer_params: Mapping whose ``'optimizer'`` entry holds the
            keyword arguments for the Adam optimizer.
        trainer_params: Mapping providing ``'num_local'``, ``'accumulation'``
            and ``'epochs'``.
    """

    # Device the model and all rollout tensors are moved to.
    _DEVICE = 'cuda'

    def __init__(self,
                 env_params,
                 model_params,
                 optimizer_params,
                 trainer_params):

        # save arguments
        self.env_params = env_params
        self.model_params = model_params
        self.optimizer_params = optimizer_params
        self.trainer_params = trainer_params

        # Main Components
        self.model = UPMSP_Model(self.model_params).to(self._DEVICE)
        self.optimizer = Optimizer(self.model.parameters(), **self.optimizer_params['optimizer'])
        self.latent_cont_dim = self.model_params['latent_cont_dim']
        self.latent_disc_dim = self.model_params['latent_disc_dim']

    def _sample_latent(self, batch_size, pomo_size):
        """Draw one latent vector per rollout.

        Each vector is a one-hot code of width ``latent_disc_dim`` followed by
        ``latent_cont_dim`` values sampled uniformly from [-1, 1).

        Returns:
            Tensor of shape ``(batch_size * pomo_size,
            latent_disc_dim + latent_cont_dim)`` on ``_DEVICE``.
        """
        # Continuous part is drawn first so the random stream is consumed in
        # the same order as before this helper was extracted.
        latent_c_var = torch.empty(batch_size, pomo_size, self.latent_cont_dim).uniform_(-1, 1)

        latent_d_var = torch.zeros((batch_size, pomo_size, self.latent_disc_dim), dtype=torch.float32)
        one_hot_idx = torch.randint(0, self.latent_disc_dim, (batch_size, pomo_size), dtype=torch.long)
        latent_d_var[torch.arange(batch_size).unsqueeze(1),
                     torch.arange(pomo_size).unsqueeze(0),
                     one_hot_idx] = 1

        latent_var = torch.cat([latent_d_var, latent_c_var], dim=-1)
        return latent_var.reshape(batch_size * pomo_size, -1).to(self._DEVICE)

    def run(self):
        """Run the training loop and save a checkpoint after every epoch.

        Loads the weights at ``env_params['fine_tuning_path']`` first when
        ``env_params['mode']`` is ``"fine_tuning"``. Checkpoints are written
        to ``./result/new_checkpoint-<epoch>.pt``; the directory is created
        if it is missing.
        """
        import os  # local import: only needed to create the checkpoint directory

        if self.env_params['mode'] == "fine_tuning":
            # map_location keeps the load working when the checkpoint was
            # saved from a different device than the one in use now.
            self.model.load_state_dict(
                torch.load(self.env_params['fine_tuning_path'], map_location=self._DEVICE))

        batch_size = self.env_params['batch_size']
        pomo_size = self.env_params['pomo_size']
        num_local = self.trainer_params['num_local']
        accumulation_step = self.trainer_params['accumulation']

        # Create the output directory up front so that a missing folder does
        # not abort the run only after a full epoch of training.
        os.makedirs("./result", exist_ok=True)

        # The range is inclusive of 'epochs', i.e. epochs + 1 passes are run.
        for epoch in range(0, self.trainer_params['epochs']+1):
            print("Epoch: ", epoch)
            self.model.train()
            for i in range(num_local):
                env_1 = Parallel_machine_tardiness(self.env_params)
                env_1._reset(batch_size, pomo_size)
                latent_var = self._sample_latent(batch_size, pomo_size)

                # Roll out until the environment reports completion, keeping
                # one [batch, pomo] log-probability slice per decision step.
                done = False
                s = env_1._get_state()
                step_log_probs = []
                while not done:
                    action, log_prob = self.model.get_action(s, latent_var)
                    step_log_probs.append(log_prob.reshape(batch_size, pomo_size))
                    s, r, done = env_1._step(action)
                # Stacking once avoids re-copying the whole history each step.
                log_prob_tmp = torch.stack(step_log_probs, dim=2)  # [batch, pomo, steps]

                reward = r.reshape(batch_size, pomo_size).to(self._DEVICE)
                best_reward, argmax = reward.max(dim=1)
                max_reward = best_reward.unsqueeze(1)  # [batch, 1]
                mean_reward = reward.mean(dim=1, keepdim=True)
                pomo_variance = reward.var(dim=1, keepdim=True, unbiased=False)  # [batch, 1]
                # Gap between best and mean rollout, in units of the rollout
                # standard deviation; the epsilon guards against zero variance.
                loss_weight = (max_reward - mean_reward) / torch.sqrt(pomo_variance + 1e-8)  # [batch, 1]

                # Only the best rollout of each instance contributes to the loss.
                batch_idx = torch.arange(batch_size, device=argmax.device)
                probs = log_prob_tmp[batch_idx, argmax, :]
                # NOTE(review): the final step's log-probability is dropped
                # here; kept as in the original - confirm this is intended.
                log_probs = probs[:, :-1]
                batch_loss = log_probs * loss_weight
                loss = -batch_loss.mean()
                loss = loss / accumulation_step

                loss.backward()
                # Also step on the last iteration so a partial accumulation
                # window is applied instead of leaking into the next epoch
                # (or being lost before the final checkpoint).
                if (i+1) % accumulation_step == 0 or (i+1) == num_local:
                    self.optimizer.step()
                    self.optimizer.zero_grad()
            torch.save(self.model.state_dict(), "./result/new_checkpoint-"+str(epoch)+".pt")
